Skip to content

Classify questions #6663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env-test
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ REUSE_DB=0
ENABLE_ADMIN=True
SET_LOCALE_PATH=False
SECURE_SSL_REDIRECT=False
GOOGLE_APPLICATION_CREDENTIALS=creds
GOOGLE_CLOUD_PROJECT=sumo-test
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,4 @@ chromedriver.log
.cursor-server
.cursor
.gk/
gcloud/
15 changes: 0 additions & 15 deletions kitsune/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +0,0 @@
from functools import cache

from langchain_google_vertexai import ChatVertexAI


@cache
def get_llm(
    model_name: str, temperature: int = 1, max_tokens: int | None = None, max_retries: int = 2
) -> ChatVertexAI:
    """
    Returns a LangChain chat model instance based on the given LLM model name.

    Args:
        model_name: Name of the Vertex AI chat model to instantiate.
        temperature: Sampling temperature passed through to the model.
        max_tokens: Optional cap on the number of tokens generated per response.
        max_retries: Number of times a failed request is retried.

    Note: memoized with functools.cache, so each unique argument combination
    yields a single shared ChatVertexAI instance.
    """
    return ChatVertexAI(
        model=model_name, temperature=temperature, max_tokens=max_tokens, max_retries=max_retries
    )
72 changes: 72 additions & 0 deletions kitsune/llm/questions/classifiers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import TYPE_CHECKING, Any

from django.db import models
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

from kitsune.llm.questions.prompt import spam_parser, spam_prompt, topic_parser, topic_prompt
from kitsune.llm.utils import get_llm
from kitsune.products.utils import get_taxonomy

DEFAULT_LLM_MODEL = "gemini-2.5-flash-preview-04-17"
HIGH_CONFIDENCE_THRESHOLD = 75
LOW_CONFIDENCE_THRESHOLD = 60

if TYPE_CHECKING:
from kitsune.questions.models import Question


class ModerationAction(models.TextChoices):
    """Moderation outcomes that LLM-based question classification can produce."""

    NOT_SPAM = "not_spam", "Not Spam"
    SPAM = "spam", "Spam"
    FLAG_REVIEW = "flag_review", "Flag for Review"


def classify_question(question: "Question") -> dict[str, Any]:
    """
    Run spam detection on a question and, unless it is confidently spam,
    classify its topic.

    Returns a dict with keys: action, spam_result, topic_result (optional).
    """
    llm = get_llm(model_name=DEFAULT_LLM_MODEL)

    product = question.product
    payload: dict[str, Any] = {
        "question": question.content,
        "product": product,
        "topics": get_taxonomy(
            product, include_metadata=["description", "examples"], output_format="JSON"
        ),
    }

    spam_chain = spam_prompt | llm | spam_parser
    topic_chain = topic_prompt | llm | topic_parser

    def moderate(data: dict[str, Any]) -> dict[str, Any]:
        # Map the spam-detection output onto a moderation action.
        spam_result: dict[str, Any] = data["spam_result"]
        outcome: dict[str, Any] = {
            "action": ModerationAction.NOT_SPAM,
            "spam_result": spam_result,
            "topic_result": {},
        }

        if spam_result.get("is_spam", False):
            confidence: int = spam_result.get("confidence", 0)
            if confidence >= HIGH_CONFIDENCE_THRESHOLD:
                outcome["action"] = ModerationAction.SPAM
            elif LOW_CONFIDENCE_THRESHOLD < confidence < HIGH_CONFIDENCE_THRESHOLD:
                outcome["action"] = ModerationAction.FLAG_REVIEW

        # Only questions not judged to be spam get a topic classification.
        if outcome["action"] == ModerationAction.NOT_SPAM:
            outcome["topic_result"] = topic_chain.invoke(data)

        return outcome

    pipeline = RunnablePassthrough.assign(spam_result=spam_chain) | RunnableLambda(moderate)
    return pipeline.invoke(payload)
33 changes: 33 additions & 0 deletions kitsune/llm/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import waffle
from celery import shared_task

from kitsune.llm.questions.classifiers import classify_question
from kitsune.users.models import Profile

# Celery task defaults: acknowledge only after completion, and retry any
# failure up to 3 times with exponential backoff.
shared_task_with_retry = shared_task(
    acks_late=True, autoretry_for=(Exception,), retry_backoff=2, retry_kwargs=dict(max_retries=3)
)


@shared_task_with_retry
def question_classifier(question_id):
    """Classify a question via the LLM, or fall back to manual-review flagging.

    Imports are deferred to call time to avoid circular imports with
    kitsune.questions.
    """
    from kitsune.questions.models import Question
    from kitsune.questions.utils import flag_question, process_classification_result

    try:
        question = Question.objects.get(id=question_id)
    except Question.DoesNotExist:
        # The question was deleted before the task ran; nothing to do.
        return

    if waffle.switch_is_active("auto-question-classifier"):
        process_classification_result(question, classify_question(question))
        return

    if waffle.switch_is_active("flagit-spam-autoflag"):
        flag_question(
            question,
            by_user=Profile.get_sumo_bot(),
            notes=(
                "Automatically flagged for topic moderation:"
                " auto-question-classifier is disabled"
            ),
        )
18 changes: 18 additions & 0 deletions kitsune/llm/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from functools import cache

from langchain_google_vertexai import ChatVertexAI


@cache
def get_llm(
model_name: str,
temperature: int = 1,
max_tokens: int | None = None,
max_retries: int = 2,
) -> ChatVertexAI:
"""
Returns a LangChain chat model instance based on the given LLM model name.
"""
return ChatVertexAI(
model=model_name, temperature=temperature, max_tokens=max_tokens, max_retries=max_retries
)
15 changes: 3 additions & 12 deletions kitsune/questions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import actstream
import actstream.actions
import waffle
from django.conf import settings
from django.contrib.auth.models import User
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
Expand All @@ -25,6 +24,7 @@
from product_details import product_details

from kitsune.flagit.models import FlaggedObject
from kitsune.llm.tasks import question_classifier
from kitsune.products.models import Product, Topic
from kitsune.questions import config
from kitsune.questions.managers import AAQConfigManager, AnswerManager, QuestionManager
Expand Down Expand Up @@ -203,17 +203,8 @@ def save(self, update=False, *args, **kwargs):
# actstream
# Authors should automatically follow their own questions.
actstream.actions.follow(self.creator, self, send_action=False, actor_only=False)
if waffle.switch_is_active("flagit-spam-autoflag"):
# Add the question to the moderation queue to validate the topic
content_type = ContentType.objects.get_for_model(self)
FlaggedObject.objects.create(
content_type=content_type,
object_id=self.id,
creator=self.creator,
status=FlaggedObject.FLAG_PENDING,
reason=FlaggedObject.REASON_CONTENT_MODERATION,
notes="New question, review topic",
)
# Either automatically classify the question or add it to the moderation queue
question_classifier.delay(self.id)

def add_metadata(self, **kwargs):
"""Add (save to db) the passed in metadata.
Expand Down
2 changes: 1 addition & 1 deletion kitsune/questions/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,6 @@ def test_question_solved_makes_action(self):
@mock.patch.object(waffle, "switch_is_active")
def test_create_question_creates_flag(self, switch_is_active):
"""Creating a question also creates a flag."""
switch_is_active.return_value = True
switch_is_active.side_effect = lambda name: name == "flagit-spam-autoflag"
QuestionFactory(title="Test Question", content="Lorem Ipsum Dolor")
self.assertEqual(1, FlaggedObject.objects.filter(reason="content_moderation").count())
71 changes: 68 additions & 3 deletions kitsune/questions/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import json
import logging
import re
from typing import Optional
from typing import Any, Optional

from django.conf import settings
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.contrib.sessions.backends.base import SessionBase

from kitsune.flagit.models import FlaggedObject
from kitsune.llm.questions.classifiers import ModerationAction
from kitsune.products.models import Product, Topic
from kitsune.questions.models import Answer, Question
from kitsune.wiki.utils import get_featured_articles as kb_get_featured_articles, has_visited_kb

from kitsune.users.models import Profile
from kitsune.wiki.utils import get_featured_articles as kb_get_featured_articles
from kitsune.wiki.utils import has_visited_kb

REGEX_NON_WINDOWS_HOME_DIR = re.compile(
r"(?P<home_dir_parent>/(?:user|users|home)/)[^/]+", re.IGNORECASE
Expand Down Expand Up @@ -138,3 +143,63 @@ def get_ga_submit_event_parameters_as_json(
data["topics"] = f"/{topic.slug}/"

return json.dumps(data)


def flag_question(
    question: Question,
    by_user: User,
    notes: str,
    status: int = FlaggedObject.FLAG_PENDING,
    reason: str = FlaggedObject.REASON_CONTENT_MODERATION,
) -> None:
    """
    Create a FlaggedObject entry for the given question.

    Args:
        question: The question to flag.
        by_user: The user recorded as the flag's creator.
        notes: Free-form text explaining why the question was flagged.
        status: Initial flag status (defaults to pending review).
        reason: Reason code for the flag (defaults to content moderation).
    """
    content_type = ContentType.objects.get_for_model(question)
    FlaggedObject.objects.create(
        content_type=content_type,
        object_id=question.id,
        creator=by_user,
        status=status,
        reason=reason,
        notes=notes,
    )


def process_classification_result(
question: Question,
result: dict[str, Any],
) -> None:
"""
Process the classification result from the LLM and take moderation action.
"""
sumo_bot = Profile.get_sumo_bot()
action = result.get("action")
match action:
case ModerationAction.SPAM:
question.mark_as_spam(sumo_bot)
case ModerationAction.FLAG_REVIEW:
flag_question(
question,
by_user=sumo_bot,
notes=(
"LLM flagged for manual review, for the following reason:\n"
f"{result['spam_result']['reason']}"
),
reason=FlaggedObject.REASON_SPAM,
)
case _:
if topic_title := result["topic_result"].get("topic"):
try:
topic = Topic.objects.get(title=topic_title)
except (Topic.DoesNotExist, Topic.MultipleObjectsReturned):
return
else:
flag_question(
question,
by_user=sumo_bot,
notes=(
"LLM classified as {topic.title}, for the following reason:\n"
f"{result['topic_result']['reason']}"
),
status=FlaggedObject.FLAG_ACCEPTED,
)
question.topic = topic
question.save()
3 changes: 3 additions & 0 deletions kitsune/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,3 +1335,6 @@ def filter_exceptions(event, hint):
SUMO_CONTENT_GROUP = config("SUMO_CONTENT_GROUP", default="Staff Content Team")

USER_INACTIVITY_DAYS = config("USER_INACTIVITY_DAYS", default=1095, cast=int)

GOOGLE_APPLICATION_CREDENTIALS = config("GOOGLE_APPLICATION_CREDENTIALS", default="")
GOOGLE_CLOUD_PROJECT = config("GOOGLE_CLOUD_PROJECT", default="")